from abc import ABC, abstractmethod
import taichi as ti

from ..utils.common import Vec2, Vec3, IdGenerator, toSpherical
from ..config import MAX_COLORER, MAX_COLORER_ATTRIBUTE
from .common import WHITE, BLACK
from .space import sRGB2lRGB

__all__ = [
    "Colorer",
    "PureColor",
    "Checker",
    "CheckerSphere",
    "Grid",
    "getSpherical",
    "MixColor",
]


@ti.dataclass
class ColorerUtil:
    """Low-level storage record for one colorer slot."""

    typeId: int  # id of the Colorer subclass that owns this slot (written in Colorer.__init__)
    attribute: ti.types.vector(MAX_COLORER_ATTRIBUTE, float)  # subclass-specific parameters, layout chosen by each subclass


class Colorer(ABC):
    """Base class for all colorers (functions mapping a hit point to a color).

    Every instance owns one slot of the shared ``Colorer.data`` struct field.
    Subclasses store their parameters in that slot's ``attribute`` vector and
    read them back inside their ``getColor`` Taichi function. Each concrete
    subclass receives a ``typeId`` the first time it is instantiated; that id
    is what ``handleAllColorer`` uses to dispatch to the right ``getColor``.
    """

    typeId: int = None  # per-subclass id, assigned on first instantiation
    default: "Colorer"  # default colorer instance, assigned outside the class body
    asEnd: bool = False  # not read in this module — TODO document its meaning
    data: ti.StructField = ColorerUtil.field(shape=MAX_COLORER)  # one slot per colorer instance
    num: ti.ScalarField = ti.field(int, 1)  # num[0] = number of slots handed out so far
    used: "dict[int, type]" = {}  # registry: typeId -> subclass
    getId = IdGenerator()  # source of fresh typeIds

    def __init__(self) -> None:
        """Claim the next free slot in ``Colorer.data`` and register the subclass.

        Raises:
            MemoryError: if all ``MAX_COLORER`` slots are already in use.
        """
        slot = Colorer.num[0]
        if slot >= MAX_COLORER:
            raise MemoryError("number of Colorers out of design.")
        cls = type(self)
        # Membership is checked by class rather than by typeId because a
        # subclass would otherwise inherit its parent's typeId attribute.
        if cls not in self.used.values():
            cls.typeId = next(self.getId)
            self.used[cls.typeId] = cls
        self.index: int = slot
        Colorer.num[0] = slot + 1
        Colorer.data[slot].typeId = self.typeId

    @abstractmethod
    def getColor(self: int, hitPoint: Vec3) -> Vec3:
        """Return the color for slot ``self`` at ``hitPoint``.

        Implementations are Taichi functions in which ``self`` is a slot index
        into ``Colorer.data``, not a Python instance. ``handleAllColorer``
        passes the result through ``sRGB2lRGB``, so it is presumably an sRGB
        value.
        """
        ...

    @ti.func
    def handleAllColorer(self: int, hitPoint: Vec3) -> Vec3:
        """Evaluate whichever colorer occupies slot ``self`` at ``hitPoint``.

        ``self`` is a slot index into ``Colorer.data``. The loop over registered
        subclasses is unrolled by ``ti.static`` when the calling kernel is
        compiled. The branch whose ``typeId`` matches the slot supplies the
        color, which is then converted with ``sRGB2lRGB``. If no subclass
        matches, ``Vec3(0)`` is converted instead.
        """
        tid = Colorer.data[self].typeId
        srgb = Vec3(0)
        for kind in ti.static(tuple(Colorer.used.values())):
            if tid == kind.typeId:
                srgb = kind.getColor(self, hitPoint)
        return sRGB2lRGB(srgb)


class PureColor(Colorer):
    """Solid colorer: returns the same RGB value at every hit point.

    Attribute layout in the colorer slot: ``[0:3]`` = (r, g, b).
    """

    def __init__(self, r: float = None, g: float = None, b: float = None) -> None:
        """Store the color channels in this colorer's slot.

        Args:
            r, g, b: channel values; each one left as ``None`` defaults to 1
                independently of the others.
        """
        super().__init__()
        # Each channel is tested against its own argument, so partial
        # overrides such as PureColor(g=0.5) behave as expected.
        for channel, value in enumerate((r, g, b)):
            Colorer.data[self.index].attribute[channel] = 1 if value is None else value

    @ti.func
    def getColor(self: int, hitPoint: Vec3) -> Vec3:
        """Return the stored (r, g, b); ``hitPoint`` is ignored."""
        return Colorer.data[self].attribute[:3]


# Module-level default colorer: PureColor() with no arguments stores 1 in all
# three channels. It is created at import time, so it occupies a Colorer slot.
Colorer.default = PureColor()


class Checker(Colorer):
    """Checkerboard colorer in 3D space.

    Attribute layout in the colorer slot:
        [0:3] color1, [3:6] color2, [6:9] freq, [9:12] offset
    """

    def __init__(
        self,
        color1: tuple[float] = WHITE,
        color2: tuple[float] = BLACK,
        freq: tuple[float] = Vec3(1),
        offset: tuple[float] = Vec3(1e-3),
    ) -> None:
        super().__init__()
        # Pack the four 3-vectors back to back into the attribute vector.
        for block, values in enumerate((color1, color2, freq, offset)):
            for axis in range(3):
                Colorer.data[self.index].attribute[3 * block + axis] = values[axis]

    @ti.func
    def getColor(self: int, hitPoint: Vec3) -> Vec3:
        """Pick a color from the parity of the summed cell index.

        The cell index is ``ceil(hitPoint * freq + offset)``. An even sum
        gives color1 and an odd sum gives color2.
        """
        attr = Colorer.data[self].attribute
        cell = ti.ceil(hitPoint * attr[6:9] + attr[9:12])
        rgb = attr[:3]
        if int(cell.sum()) % 2:
            rgb = attr[3:6]
        return rgb


class CheckerSphere(Colorer):
    """Checkerboard colorer in angular (spherical) coordinates.

    Attribute layout in the colorer slot:
        [0:3] color1, [3:6] color2, [6:8] freq
    """

    def __init__(
        self,
        color1: tuple[float] = WHITE,
        color2: tuple[float] = BLACK,
        freq: tuple[float] = Vec2(1),
    ) -> None:
        super().__init__()
        for start, values in ((0, color1), (3, color2)):
            for axis in range(3):
                Colorer.data[self.index].attribute[start + axis] = values[axis]
        for axis in range(2):
            Colorer.data[self.index].attribute[6 + axis] = freq[axis]

    @ti.func
    def getColor(self: int, hitPoint: Vec3) -> Vec3:
        """Checker the two angles of ``hitPoint`` (in radians) scaled by freq.

        The angles are the azimuth ``atan2(z, x)`` around the y axis and the
        elevation ``atan2(y, |xz|)`` above the xz plane. An even summed cell
        index gives color1 and an odd one gives color2.
        """
        attr = Colorer.data[self].attribute
        horizontal = (hitPoint.xz**2).sum() ** 0.5
        angles = Vec2(ti.atan2(hitPoint.z, hitPoint.x), ti.atan2(hitPoint.y, horizontal))
        rgb = attr[:3]
        if int(ti.ceil(angles * attr[6:8]).sum()) % 2:
            rgb = attr[3:6]
        return rgb


class Grid(Colorer):
    """Grid-line colorer in 3D space.

    Attribute layout in the colorer slot:
        [0:3] color1 (background), [3:6] color2 (lines),
        [6:9] freq, [9:12] offset, [12] width
    """

    def __init__(
        self,
        color1: tuple[float] = WHITE,
        color2: tuple[float] = BLACK,
        freq: tuple[float] = Vec3(1),
        offset: tuple[float] = Vec3(1e-3),
        width: float = 0.02,
    ) -> None:
        super().__init__()
        # Pack the four 3-vectors back to back, then the scalar line width.
        for block, values in enumerate((color1, color2, freq, offset)):
            for axis in range(3):
                Colorer.data[self.index].attribute[3 * block + axis] = values[axis]
        Colorer.data[self.index].attribute[12] = width

    @ti.func
    def getColor(self: int, hitPoint: Vec3) -> Vec3:
        """Return color2 on grid lines and color1 elsewhere.

        A point is on a line when, for any axis, the fractional part of
        ``hitPoint * freq + offset + width / 2`` is below ``width``. That is a
        band of total width ``width`` centred on each integer value.
        """
        attr = Colorer.data[self].attribute
        phase = ti.math.fract(hitPoint * attr[6:9] + attr[9:12] + attr[12] / 2)
        rgb = attr[:3]
        if any(phase < attr[12]):
            rgb = attr[3:6]
        return rgb


def getSpherical(texture):
    """Create a colorer that samples ``texture`` through spherical coordinates.

    Each call defines a fresh ``Spherical`` subclass so that ``texture`` is
    bound into its ``getColor`` by closure. The new subclass therefore gets
    its own ``typeId`` when it is instantiated.

    Args:
        texture: object with a ``sample(uv)`` method callable from Taichi
            scope (assumed; only ``texture.sample`` is used here).

    Returns:
        A new instance of the generated ``Spherical`` colorer.
    """

    class Spherical(Colorer):
        """Colorer mapping hit points to texture coordinates via ``toSpherical``."""

        @ti.func
        def getColor(self: int, hitPoint: Vec3) -> Vec3:
            coords = toSpherical(hitPoint)
            # Mirror the horizontal texture coordinate before sampling.
            coords.x = 1 - coords.x
            return texture.sample(coords)

    instance = Spherical()
    return instance


class MixColor(Colorer):
    """Colorer that picks one of up to three colors at random per evaluation.

    Mainly intended for debugging spectra. Attribute layout in the colorer slot:
        [0] number of colors n, [1:4] first color, [4:7] second color,
        [7:10] third color
    """

    def __init__(self, *color: Vec3) -> None:
        """Store between one and three RGB colors.

        Args:
            *color: one to three RGB triples.

        Raises:
            ValueError: if fewer than one or more than three colors are given,
                or if a color does not unpack into exactly three components.
        """
        n = len(color)
        # An explicit check (instead of ``assert``) survives ``python -O``.
        # Validating before super().__init__() means bad input no longer
        # consumes a colorer slot.
        if not 0 < n <= 3:
            raise ValueError(f"MixColor expects 1 to 3 colors, got {n}.")
        triples = [(r, g, b) for r, g, b in color]
        super().__init__()
        Colorer.data[self.index].attribute[0] = n
        for i, (r, g, b) in enumerate(triples):
            base = i * 3 + 1
            Colorer.data[self.index].attribute[base] = r
            Colorer.data[self.index].attribute[base + 1] = g
            Colorer.data[self.index].attribute[base + 2] = b

    @ti.func
    def getColor(self: int, hitPoint: Vec3) -> Vec3:
        """Return one stored color chosen uniformly at random; ``hitPoint`` is ignored."""
        # ti.random() is in [0, 1), so i ranges over 0 .. n-1.
        i = int(ti.random() * Colorer.data[self].attribute[0])
        ret = Vec3(0)
        if i == 0:
            ret = Colorer.data[self].attribute[1:4]
        elif i == 1:
            ret = Colorer.data[self].attribute[4:7]
        elif i == 2:
            ret = Colorer.data[self].attribute[7:10]
        return ret
